import numpy as np
from matplotlib import pyplot as plt

from sklearn.tree import DecisionTreeClassifier
from sklearn import tree


if __name__ == '__main__':
    # Toy training set: one row per sample.
    # Columns, in order: age, income, appearance, civil servant.
    feature_rows = [
        [28, 1, 2, 1],
        [35, 2, 1, 0],
        [26, 1, 2, 0],
        [31, 2, 1, 1],
        [26, 1, 2, 1],
        [25, 0, 1, 0],
        [30, 2, 0, 1],
    ]
    # One binary target per row above, in the same order.
    labels = [1, 0, 0, 0, 1, 0, 0]

    X_train = np.array(feature_rows)
    Y_train = np.array(labels)

    # A fixed random_state keeps the fitted tree reproducible between runs;
    # fit() returns the estimator itself, so construction and training chain.
    clf = DecisionTreeClassifier(random_state=0).fit(X_train, Y_train)

    # Render the fitted tree with matplotlib, then display the figure.
    tree.plot_tree(clf)
    plt.show()
